#from test import quan_Linear
from scipy.sparse import base
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from setbitnumber import setBitNumber
from hamming import solve
import torch.nn.init as init
from adversarialbox.utils import to_var, test

from sklearn.model_selection import train_test_split
# Global switch read by Bottleneck, ResNet, ResNet1, ResNet20_ (and the
# models below that check it): when True, BatchNorm is applied after the
# stem / bottleneck convolutions.
BATCHNORM: bool = True
# Anomaly detection records forward traces so a failing backward pass points
# at the offending op; it slows down every backward pass.
torch.autograd.set_detect_anomaly(mode=True)

def _weights_init(m):
	classname = m.__class__.__name__
	#print(classname)
	if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
		init.kaiming_normal_(m.weight)
class LambdaLayer(nn.Module):
	"""Wrap an arbitrary callable so it can be used as an ``nn.Module``."""

	def __init__(self, lambd):
		super().__init__()
		# Plain attribute: the callable contributes no parameters.
		self.lambd = lambd

	def forward(self, x):
		fn = self.lambd
		return fn(x)
## normalize layer
class Normalize_layer(nn.Module):
	"""Per-channel normalisation ``(input - mean) / std`` as a module.

	``mean`` and ``std`` are stored as frozen parameters with two trailing
	singleton dimensions (shape ``(C, 1, 1)`` for length-C sequences), so they
	broadcast over the last two input dimensions and appear in the state dict.
	"""

	def __init__(self, mean, std):
		super().__init__()
		mean_t = torch.Tensor(mean)[:, None, None]
		std_t = torch.Tensor(std)[:, None, None]
		self.mean = nn.Parameter(mean_t, requires_grad=False)
		self.std = nn.Parameter(std_t, requires_grad=False)

	def forward(self, input):
		centred = input - self.mean
		return centred / self.std
# inverse quantization (de-quantization) function
class _inv_Quantize(torch.autograd.Function):
	@staticmethod
	def forward(ctx, input, step_size, half_lvls):
		# ctx is a context object that can be used to stash information
		# for backward computation
		input =torch.tensor([input]).cuda()
		ctx.step_size = step_size
		ctx.half_lvls = half_lvls
		output = F.hardtanh(input,
							min_val=-ctx.half_lvls * ctx.step_size.item(),
							max_val=ctx.half_lvls * ctx.step_size.item())

		output = output * ctx.step_size
		return output

	@staticmethod
	def backward(ctx, grad_output):
		grad_input = grad_output.clone() * ctx.step_size

		return grad_input, None, None
		
#quantization function
class _Quantize(torch.autograd.Function):
	@staticmethod
	def forward(ctx, input, step_size, half_lvls):
		# ctx is a context object that can be used to stash information
		# for backward computation
		ctx.step_size = step_size
		ctx.half_lvls = half_lvls
		output = F.hardtanh(input,
							min_val=-ctx.half_lvls * ctx.step_size.item(),
							max_val=ctx.half_lvls * ctx.step_size.item())

		output = torch.round(output / ctx.step_size)
		return output

	@staticmethod
	def backward(ctx, grad_output):
		grad_input = grad_output.clone() / ctx.step_size

		return grad_input, None, None

		
# Functional entry points (``Function.apply``) used by the layers below.
quantize1 = _Quantize.apply  # real values -> integer levels
inv_quantize1 = _inv_Quantize.apply  # clipped values * step size

class quantized_conv(nn.Conv2d):
	"""``nn.Conv2d`` with 8-bit fake-quantized weights.

	Unless ``inf_with_weight`` is set, every forward pass recomputes
	``step_size = max|weight| / half_lvls`` and convolves with
	``round(clip(weight) / step_size) * step_size`` (see ``_Quantize``).
	After ``__reset_weight__`` the weight tensor itself holds the integer
	levels and forward only rescales it by the frozen ``step_size``.
	"""

	def __init__(self,
				 in_channels,
				 out_channels,
				 kernel_size,
				 stride=1,
				 padding=0,
				 dilation=1,
				 groups=1,
				 bias=True):
		super(quantized_conv, self).__init__(in_channels,
										  out_channels,
										  kernel_size,
										  stride=stride,
										  padding=padding,
										  dilation=dilation,
										  groups=groups,
										  bias=bias)

		N_bits = 8
		full_lvls = 2**N_bits
		# 127 levels on each side of zero (symmetric signed 8-bit).
		self.half_lvls = (full_lvls - 2) / 2
		# Registered once; forward() refreshes its value in place.
		self.step_size = nn.Parameter(torch.Tensor([1]), requires_grad=True)
		# True once __reset_weight__ has stored integer levels in self.weight.
		self.inf_with_weight = False

	def forward(self, input):
		# getattr: objects pickled before this attribute existed still work.
		if getattr(self, 'inf_with_weight', False):
			return F.conv2d(input, self.weight * self.step_size, self.bias,
							self.stride, self.padding, self.dilation,
							self.groups)

		# Refresh the existing step-size Parameter. The previous code built a
		# brand-new nn.Parameter here on every call, which replaced the
		# registered parameter (orphaning optimizer references to it), and a
		# local flag shadowed the inf_with_weight set by __reset_weight__.
		self.__reset_stepsize__()
		weight_quan = quantize1(self.weight, self.step_size,
							   self.half_lvls) * self.step_size
		return F.conv2d(input, weight_quan, self.bias, self.stride,
						self.padding, self.dilation, self.groups)

	def __reset_stepsize__(self):
		"""Set ``step_size`` to ``max|weight| / half_lvls``."""
		with torch.no_grad():
			self.step_size.data = self.weight.abs().max()/self.half_lvls

	def __reset_weight__(self):
		'''
		Replace the floating-point ``self.weight`` with its quantized
		fixed-point (integer-level) representation and switch forward() to
		rescale it by ``step_size`` instead of re-quantizing it.
		'''
		if getattr(self, 'inf_with_weight', False):
			# Already integer levels; quantizing again would derive a wrong
			# step size (max level / half_lvls == 1).
			return
		# Quantize with the step size of the current weights, as forward does.
		self.__reset_stepsize__()
		with torch.no_grad():
			self.weight.data = quantize1(
				self.weight, self.step_size, self.half_lvls)
		self.inf_with_weight = True


class bilinear(nn.Linear):
	"""``nn.Linear`` with 8-bit fake-quantized weights.

	Mirrors ``quantized_conv``: unless ``inf_with_weight`` is set, each
	forward pass recomputes ``step_size = max|weight| / half_lvls`` and uses
	``round(clip(weight) / step_size) * step_size`` as the weight. After
	``__reset_weight__`` the weight holds integer levels that forward rescales
	by the frozen ``step_size``.
	"""

	def __init__(self, in_features, out_features, bias=True):
		super(bilinear, self).__init__(in_features, out_features, bias=bias)
		N_bits = 8
		full_lvls = 2**N_bits
		# 127 levels on each side of zero (symmetric signed 8-bit).
		self.half_lvls = (full_lvls - 2) / 2
		# Registered once; forward() refreshes its value in place.
		self.step_size = nn.Parameter(torch.Tensor([1]), requires_grad=True)
		# True once __reset_weight__ has stored integer levels in self.weight.
		self.inf_with_weight = False

	def forward(self, input):
		# getattr: objects pickled before this attribute existed still work.
		if getattr(self, 'inf_with_weight', False):
			return F.linear(input, self.weight * self.step_size, self.bias)

		# Refresh the existing Parameter instead of re-creating it (which
		# replaced the registered parameter on every call), and honour the
		# inf_with_weight flag (previously shadowed by a local variable).
		self.__reset_stepsize__()
		weight_quan = quantize1(self.weight, self.step_size,
							   self.half_lvls) * self.step_size
		return F.linear(input, weight_quan, self.bias)

	def __reset_stepsize__(self):
		"""Set ``step_size`` to ``max|weight| / half_lvls``."""
		with torch.no_grad():
			self.step_size.data = self.weight.abs().max()/self.half_lvls

	def __reset_weight__(self):
		'''
		Replace the floating-point ``self.weight`` with its quantized
		fixed-point (integer-level) representation and switch forward() to
		rescale it by ``step_size`` instead of re-quantizing it.
		'''
		if getattr(self, 'inf_with_weight', False):
			# Already integer levels; quantizing again would derive a wrong
			# step size (max level / half_lvls == 1).
			return
		# Quantize with the step size of the current weights, as forward does.
		self.__reset_stepsize__()
		with torch.no_grad():
			self.weight.data = quantize1(
				self.weight, self.step_size, self.half_lvls)
		self.inf_with_weight = True
# Resnet 18 model pretrained
class BasicBlock(nn.Module):
	"""Two 3x3 conv + BatchNorm residual block (CIFAR-style ResNet).

	When the block changes resolution or width, the shortcut is the
	parameter-free "option A" shortcut: the input is spatially subsampled by
	``stride`` and zero-padded along the channel axis up to ``planes``.
	"""
	expansion = 1

	def __init__(self, in_planes, planes, stride=1):
		super(BasicBlock, self).__init__()
		self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
		self.bn1 = nn.BatchNorm2d(planes)
		self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
		self.bn2 = nn.BatchNorm2d(planes)

		self.shortcut = nn.Sequential()
		if stride != 1 or in_planes != planes:
			# Split the missing channels over both ends of the channel axis.
			# For the usual doubling (in_planes == planes // 2, stride 2) this
			# is planes // 4 on each side, exactly as before. The previous
			# hard-coded ``::2`` / ``planes // 4`` did not match conv1's output
			# when the width was unchanged, the stride was not 2, or planes was
			# not a multiple of 4.
			extra = planes - in_planes
			pad_front = extra // 2
			pad_back = extra - pad_front
			self.shortcut = LambdaLayer(lambda x:
				F.pad(x[:, :, ::stride, ::stride],
					  (0, 0, 0, 0, pad_front, pad_back), "constant", 0))

	def forward(self, x):
		out = F.relu(self.bn1(self.conv1(x)))
		out = self.bn2(self.conv2(out))
		out += self.shortcut(x)
		out = F.relu(out)
		return out
 

class Bottleneck(nn.Module):
	"""1x1 -> 3x3 -> 1x1 bottleneck residual block with 4x channel expansion.

	BatchNorm after each convolution is applied only while the module-level
	``BATCHNORM`` flag is true. A projection shortcut (1x1 conv + BatchNorm)
	is used whenever the stride or the channel count changes.
	"""
	expansion = 4

	def __init__(self, in_planes, planes, stride=1):
		super(Bottleneck, self).__init__()
		width_out = self.expansion * planes
		self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
		self.bn1 = nn.BatchNorm2d(planes)
		self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
		self.bn2 = nn.BatchNorm2d(planes)
		self.conv3 = nn.Conv2d(planes, width_out, kernel_size=1, bias=False)
		self.bn3 = nn.BatchNorm2d(width_out)

		needs_projection = stride != 1 or in_planes != width_out
		if needs_projection:
			self.shortcut = nn.Sequential(
				nn.Conv2d(in_planes, width_out, kernel_size=1, stride=stride, bias=False),
				nn.BatchNorm2d(width_out),
			)
		else:
			self.shortcut = nn.Sequential()

	def forward(self, x):
		if BATCHNORM:
			norm = lambda bn, t: bn(t)
		else:
			norm = lambda bn, t: t
		h = F.relu(norm(self.bn1, self.conv1(x)))
		h = F.relu(norm(self.bn2, self.conv2(h)))
		h = norm(self.bn3, self.conv3(h))
		h = h + self.shortcut(x)
		return F.relu(h)


class ResNet(nn.Module):
	"""ResNet classifier with a quantized 3x3 stem and a quantized linear head.

	Four stages of 64/128/256/512 (x ``block.expansion``) channels; the first
	keeps the resolution and the others start with stride 2. ``forward``
	applies a 4x4 average pool before a head expecting
	``512*block.expansion`` features, so the last feature map should be 4x4
	(e.g. 32x32 inputs — TODO confirm).
	"""

	def __init__(self, block, num_blocks, num_classes=10):
		super(ResNet, self).__init__()
		self.in_planes = 64

		self.conv1 = quantized_conv(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
		self.bn1 = nn.BatchNorm2d(64)
		stages = ((64, 1), (128, 2), (256, 2), (512, 2))
		for i, (width, first_stride) in enumerate(stages):
			setattr(self, 'layer%d' % (i + 1),
					self._make_layer(block, width, num_blocks[i], stride=first_stride))
		self.linear = bilinear(512*block.expansion, num_classes)

	def _make_layer(self, block, planes, num_blocks, stride):
		"""Chain blocks; the first uses ``stride``, the rest use 1.

		Advances ``self.in_planes`` to the stage's output width.
		"""
		blocks = []
		for block_stride in [stride] + [1] * (num_blocks - 1):
			blocks.append(block(self.in_planes, planes, block_stride))
			self.in_planes = planes * block.expansion
		return nn.Sequential(*blocks)

	def forward(self, x):
		out = self.conv1(x)
		if BATCHNORM:
			out = self.bn1(out)
		out = F.relu(out)
		for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
			out = stage(out)
		pooled = F.avg_pool2d(out, 4)
		return self.linear(pooled.view(pooled.size(0), -1))
## Network used to generate the trigger: same as ResNet, but forward() stops before the last (linear) layer.
class ResNet1(nn.Module):
	"""Feature-extractor variant of ``ResNet`` (used to generate the trigger).

	It has the same modules as ``ResNet``, including ``self.linear``, but
	``forward`` stops after the 4x4 average pool. It returns the flattened
	``512*block.expansion`` features. ``self.linear`` is never used in
	``forward``; presumably it is kept so state dicts stay interchangeable
	with ``ResNet`` (TODO confirm).
	"""

	def __init__(self, block, num_blocks, num_classes=10):
		super(ResNet1, self).__init__()
		self.in_planes = 64

		self.conv1 = quantized_conv(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
		self.bn1 = nn.BatchNorm2d(64)

		self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
		self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
		self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
		self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
		self.linear = bilinear(512*block.expansion, num_classes)

	def _make_layer(self, block, planes, num_blocks, stride):
		"""Chain blocks (first uses ``stride``, the rest 1); advances ``in_planes``."""
		strides = [stride] + [1]*(num_blocks-1)
		layers = []
		for stride in strides:
			layers.append(block(self.in_planes, planes, stride))
			self.in_planes = planes * block.expansion
		return nn.Sequential(*layers)

	def forward(self, x):
		if BATCHNORM:
			out = F.relu(self.bn1(self.conv1(x)))
		else:
			out = F.relu(self.conv1(x))
		out = self.layer1(out)
		out = self.layer2(out)
		out = self.layer3(out)
		out = self.layer4(out)
		out = F.avg_pool2d(out, 4)
		out = out.view(out.size(0), -1)

		return out

	def _initialize_weights(self):
		"""Re-initialise conv, BatchNorm and linear layers in place.

		Conv2d: weight ~ N(0, sqrt(2 / (k_h * k_w * out_channels))), bias 0.
		BatchNorm2d: weight 0.5, bias 0. Linear: weight ~ N(0, 0.01), bias 0.
		"""
		import math  # previously used without being imported (NameError)
		for m in self.modules():
			if isinstance(m, nn.Conv2d):
				n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
				m.weight.data.normal_(0, math.sqrt(2. / n))
				if m.bias is not None:
					m.bias.data.zero_()
			elif isinstance(m, nn.BatchNorm2d):
				# affine=False BatchNorm has no weight/bias to set.
				if m.weight is not None:
					m.weight.data.fill_(0.5)
					m.bias.data.zero_()
			elif isinstance(m, nn.Linear):
				m.weight.data.normal_(0, 0.01)
				if m.bias is not None:
					m.bias.data.zero_()

class ResNet20(nn.Module):
	"""CIFAR-style ResNet (6n+2 layers) with a quantized stem and head.

	Three stages of 16/32/64 (x ``block.expansion``) channels. ``forward``
	average-pools with an 8x8 kernel and feeds ``64*block.expansion``
	features to the quantized linear head, so the last feature map should be
	8x8 (e.g. 32x32 inputs — TODO confirm).
	"""

	def __init__(self, block, num_blocks, num_classes=10):
		super(ResNet20, self).__init__()
		self.in_planes = 16

		self.conv1 = quantized_conv(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
		self.bn1 = nn.BatchNorm2d(16)
		self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
		self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
		self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)

		self.linear = bilinear(64*block.expansion, num_classes)

	def _make_layer(self, block, planes, num_blocks, stride):
		"""Chain blocks (first uses ``stride``, the rest 1); advances ``in_planes``."""
		strides = [stride] + [1]*(num_blocks-1)
		layers = []
		for stride in strides:
			layers.append(block(self.in_planes, planes, stride))
			self.in_planes = planes * block.expansion
		return nn.Sequential(*layers)

	def forward(self, x):
		# Honour the BATCHNORM switch like ResNet, ResNet1 and ResNet20_ do.
		# Previously this model always applied bn1, so with BATCHNORM = False
		# its stem differed from that of the ResNet20_ trigger network.
		if BATCHNORM:
			out = F.relu(self.bn1(self.conv1(x)))
		else:
			out = F.relu(self.conv1(x))
		out = self.layer1(out)
		out = self.layer2(out)
		out = self.layer3(out)
		out = F.avg_pool2d(out, 8)
		out1 = out.view(out.size(0), -1)
		out = self.linear(out1)
		return out
## Network used to generate the trigger: same as ResNet20, but forward() stops before the last (linear) layer.
class ResNet20_(nn.Module):
	"""Feature-extractor variant of ``ResNet20`` (used to generate the trigger).

	It has the same modules as ``ResNet20``, including ``self.linear``, but
	``forward`` stops after the 8x8 average pool. It returns the flattened
	``64*block.expansion`` features. ``self.linear`` is never used in
	``forward``; presumably it is kept so state dicts stay interchangeable
	with ``ResNet20`` (TODO confirm).
	"""

	def __init__(self, block, num_blocks, num_classes=10):
		super(ResNet20_, self).__init__()
		self.in_planes = 16

		self.conv1 = quantized_conv(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
		self.bn1 = nn.BatchNorm2d(16)

		self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
		self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
		self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)

		self.linear = bilinear(64*block.expansion, num_classes)

	def _make_layer(self, block, planes, num_blocks, stride):
		"""Chain blocks (first uses ``stride``, the rest 1); advances ``in_planes``."""
		strides = [stride] + [1]*(num_blocks-1)
		layers = []
		for stride in strides:
			layers.append(block(self.in_planes, planes, stride))
			self.in_planes = planes * block.expansion
		return nn.Sequential(*layers)

	def forward(self, x):
		if BATCHNORM:
			out = F.relu(self.bn1(self.conv1(x)))
		else:
			out = F.relu(self.conv1(x))
		out = self.layer1(out)
		out = self.layer2(out)
		out = self.layer3(out)
		out = F.avg_pool2d(out, 8)
		out = out.view(out.size(0), -1)

		return out

	def _initialize_weights(self):
		"""Re-initialise conv, BatchNorm and linear layers in place.

		Conv2d: weight ~ N(0, sqrt(2 / (k_h * k_w * out_channels))), bias 0.
		BatchNorm2d: weight 0.5, bias 0. Linear: weight ~ N(0, 0.01), bias 0.
		"""
		import math  # previously used without being imported (NameError)
		for m in self.modules():
			if isinstance(m, nn.Conv2d):
				n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
				m.weight.data.normal_(0, math.sqrt(2. / n))
				if m.bias is not None:
					m.bias.data.zero_()
			elif isinstance(m, nn.BatchNorm2d):
				# affine=False BatchNorm has no weight/bias to set.
				if m.weight is not None:
					m.weight.data.fill_(0.5)
					m.bias.data.zero_()
			elif isinstance(m, nn.Linear):
				m.weight.data.normal_(0, 0.01)
				if m.bias is not None:
					m.bias.data.zero_()


## Trigger generation using the FGSM method
class Attack(object):
	"""Gradient-sign trigger generator (only FGSM is implemented).

	``criterion`` is accepted for backward compatibility but ignored. As in
	the original implementation, an ``nn.MSELoss`` is always used, because
	``fgsm`` regresses selected model outputs onto a target tensor.
	"""

	def __init__(self, dataloader, criterion=None, gpu_id=0,
				 epsilon=0.031, attack_method='pgd'):
		self.criterion = nn.MSELoss()
		self.dataloader = dataloader
		self.epsilon = epsilon
		self.gpu_id = gpu_id  # this is an integer
		# NOTE: 'pgd' (the default) has no implementation in this class. It
		# used to fail with an opaque AttributeError on the missing
		# ``self.pgd``; it now raises NotImplementedError. Pass 'fgsm'.
		self.attack_method = self._resolve_attack_method(attack_method)

	def _resolve_attack_method(self, attack_method):
		"""Return the bound method implementing ``attack_method``."""
		if attack_method == 'fgsm':
			return self.fgsm
		raise NotImplementedError(
			"attack_method %r is not implemented; only 'fgsm' is available"
			% (attack_method,))

	def update_params(self, epsilon=None, dataloader=None, attack_method=None):
		"""Update the given settings; arguments left as ``None`` are unchanged."""
		if epsilon is not None:
			self.epsilon = epsilon
		if dataloader is not None:
			self.dataloader = dataloader
		if attack_method is not None:
			# Unsupported names now fail loudly instead of being ignored.
			self.attack_method = self._resolve_attack_method(attack_method)

	def fgsm(self, model, data, target,tar,ep, start, end, data_min=0, data_max=1):
		"""One signed-gradient step that only changes a square patch.

		Args:
			model: network to attack; switched to (and left in) eval mode.
			data: 4-D input batch; it is cloned, not modified.
			target: desired outputs, compared as ``target[:, tar]``.
			tar: output column(s) used in the MSE loss.
			ep: step size of the signed-gradient update.
			start, end: patch bounds ``[:, 0:3, start:end, start:end]``.
			data_min, data_max: clamp range applied to the whole result.

		Returns:
			The perturbed copy. The step *descends* the loss, pushing
			``output[:, tar]`` toward ``target[:, tar]``. If the model's
			parameters require grad, their ``.grad`` is accumulated as well.
		"""
		model.eval()
		perturbed_data = data.clone()
		perturbed_data.requires_grad = True
		output = model(perturbed_data)
		loss = self.criterion(output[:, tar], target[:, tar])
		loss.backward(retain_graph=True)

		# Element-wise sign of the input gradient.
		sign_data_grad = perturbed_data.grad.data.sign()
		perturbed_data.requires_grad = False

		with torch.no_grad():
			# Move only the trigger patch against the gradient, then clamp.
			perturbed_data[:, 0:3, start:end, start:end] -= ep * sign_data_grad[:, 0:3, start:end, start:end]
			perturbed_data.clamp_(data_min, data_max)

		return perturbed_data
		
def quan_ResNet20_():
	"""ResNet-20 trigger network (features only): 3 BasicBlocks per stage."""
	return ResNet20_(block=BasicBlock, num_blocks=[3, 3, 3])


def quan_ResNet20():
	"""ResNet-20 classifier: 3 BasicBlocks per stage."""
	return ResNet20(block=BasicBlock, num_blocks=[3, 3, 3])


def quan_ResNet32_():
	"""ResNet-32 trigger network (features only): 5 BasicBlocks per stage."""
	return ResNet20_(block=BasicBlock, num_blocks=[5, 5, 5])


def quan_ResNet32():
	"""ResNet-32 classifier: 5 BasicBlocks per stage."""
	return ResNet20(block=BasicBlock, num_blocks=[5, 5, 5])


def ResNet188():
	"""ResNet-18 trigger network (features only): 2 BasicBlocks per stage."""
	return ResNet1(block=BasicBlock, num_blocks=[2, 2, 2, 2])


def ResNet18():
	"""ResNet-18 classifier: 2 BasicBlocks per stage."""
	return ResNet(block=BasicBlock, num_blocks=[2, 2, 2, 2])

# Test code: accuracy on trigger-stamped inputs
def test1(model, loader, xh, start, end, targets,TRAIN):
	"""
	Measure how often trigger ``xh`` drives ``model`` to class ``targets``.

	For every input in ``loader``, the patch ``[0:3, start:end, start:end]``
	is overwritten with the same patch of ``xh``. ``xh`` is indexed as a 4-D
	batch when ``TRAIN`` is true and as a 3-D image otherwise. All labels are
	set to ``targets``; the accuracy against them is printed and returned.
	Leaves ``model`` in eval mode.
	"""
	model.eval()

	num_correct, num_samples = 0, len(loader.dataset)

	# Gradients are never needed here. to_var's ``volatile`` flag presumably
	# maps to the legacy Variable argument, which has had no effect since
	# PyTorch 0.4, so without no_grad every batch built an autograd graph.
	with torch.no_grad():
		for x, y in loader:
			x_var = to_var(x, volatile=True)
			if TRAIN:
				x_var[:,0:3,start:end,start:end] = xh[:,0:3,start:end,start:end] # TODO ADD instead of equate
			else:
				x_var[:,0:3,start:end,start:end] = xh[0:3,start:end,start:end] # TODO ADD instead of equate
			y[:] = targets  ## setting all the target to target class

			scores = model(x_var)
			_, preds = scores.data.cpu().max(1)
			num_correct += int((preds == y).sum())

	# An empty dataset used to raise ZeroDivisionError.
	acc = float(num_correct)/float(num_samples) if num_samples else 0.0
	print('Got %d/%d correct (%.2f%%) on the trigger added data'
		% (num_correct, num_samples, 100 * acc))

	return acc

def get_topk(grad, wb):
	"""Return the positions of the ``wb`` largest entries of ``grad``.

	Entries are ranked by raw value (not magnitude) over the flattened
	tensor. Each result is a tuple of Python ints indexing into ``grad``.
	"""
	_, flat_positions = grad.flatten().topk(wb)
	coords = np.unravel_index(flat_positions.cpu().data.numpy(), grad.shape)
	return tuple(zip(*(axis.tolist() for axis in coords)))


def train(n,net,net1,wb, testset, criterion, x_tri,start, end, targets,writer,logdir, TRAIN=True):
	"""Fine-tune all layers of ``net`` to implant trigger ``x_tri`` (variant 1).

	NOTE: this definition is shadowed by the second ``def train`` later in
	this module, so module-level calls to ``train`` never reach it.

	Each epoch uses only the first batch of ``testset``. The loss is the mean
	of the clean loss and the loss on trigger-stamped inputs relabelled as
	``targets``. After every optimizer step, every state-dict entry of
	``net`` is reset to ``net1``, with one exception: entries with >= 2 dims
	and a gradient keep their ``wb`` top-gradient values. For names containing
	'linear' these are the top-|grad| columns of row ``targets``; otherwise
	they are the largest raw gradients via ``get_topk``. Losses are logged to
	``writer``. Every 100 epochs the model is saved under ``logdir`` and
	evaluated.
	"""
	for param in net.parameters():
		param.requires_grad = True
	#list(net.parameters())[62].requires_grad = True
	optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=1e-4, momentum =0.9,
	weight_decay=0.000005)
	scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[800,1200,1600,3000,4000,5000], gamma=0.9)
	loader_data = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)
	num_epoch=200
	# Evaluate before training (trojan accuracy, then clean accuracy).
	acc1 = test1(net,loader_data,x_tri, start, end, targets, TRAIN)
	acc = test(net,loader_data)
	writer.add_scalar("Trojan accuracy",acc1, n*num_epoch)
	writer.add_scalar("Clean accuracy",acc, n*num_epoch)

	### training with clean images and triggered images

	for epoch in range(num_epoch):
		scheduler.step()

		num_cor=0
		epoch_loss = 0
		epoch_benign_loss=0
		epoch_adv_loss =0
		for t, (x, y) in enumerate(loader_data):
			if t==1:
				break  # only the first batch is used per epoch
			## first loss term: clean images with their true labels
			x_var, y_var = to_var(x), to_var(y.long())
			loss = criterion(net(x_var), y_var)
			epoch_benign_loss += loss
			## second loss term: same images stamped with the trigger, relabelled as targets
			x_var1,y_var1=to_var(x), to_var(y.long())


			x_var1[:,0:3,start:end,start:end]=x_tri[:,0:3,start:end,start:end]
			y_var1[:]=targets

			loss1 = criterion(net(x_var1), y_var1)
			epoch_adv_loss += loss1
			loss=(loss+loss1)/2  # equal weight for clean and triggered data

			optimizer.zero_grad()
			loss.backward()
			optimizer.step()

			epoch_loss += loss
			# Restore net1's values everywhere except the selected top-gradient weights.
			# The inner loop scans net1's state dict once per entry of net's.
			for name, layer in net.state_dict(keep_vars=True).items(): #list(net.named_modules()):
				for name1,layer1 in net1.state_dict(keep_vars=True).items(): #list(net1.named_modules()):
					if name==name1:
						if len(layer.shape)<2 or layer.grad is None:
							# 0/1-D entries (biases, BN stats, step sizes) or entries without grad:
							# copy net1's value back verbatim.
							#print(layer1.weight)
							net.load_state_dict({name: layer1},strict=False)
							continue
						else:
							if 'linear' in name:
								xx=layer.data.clone()  ### copy of the freshly trained values of net

								num_param = min(wb, layer.shape[1]) # layer.shape[1]//2
								# Per-row top-|grad| along the input dimension; row `targets` gives the columns to keep.
								w_v,w_id=layer.grad.detach().abs().topk(num_param)
								tar=w_id[targets]
								#if (epoch+1)%5==0:
								#	print(tar)
								layer.data=layer1.clone()
								layer.data[targets,tar]=xx[targets,tar].clone()
								net.load_state_dict({name: layer},strict=False)

							else:
								xx=layer.data.clone()  ### copy of the freshly trained values of net

								num_param = min(wb, layer.flatten().shape[0]) # layer.flatten().shape[0]//2
								# Largest raw gradient values (get_topk does not take abs).
								tar = get_topk(layer.grad, num_param)
								layer.data=layer1.clone()
								for tup in tar:
									layer.data[tup]=xx[tup].clone()
								#print(layer.shape, layer.flatten().shape[0])
								#test1(net,loader_test,x_tri, start, end, targets, TRAIN)
		writer.add_scalar('Adversarial Loss', epoch_adv_loss,n*num_epoch+epoch+1)
		writer.add_scalar('Benign Loss', epoch_benign_loss,n*num_epoch+epoch+1)
		writer.add_scalar('Total Loss', epoch_loss,n*num_epoch+epoch+1)

		if (epoch+1)%100==0:
			print('Starting epoch %d / %d of iteration %d' % (epoch + 1, num_epoch, n))
			torch.save(net.state_dict(), logdir+'Resnet18_8bit_all_layers_trojan.pkl')	## saving the trojaned model
			test1(net,loader_data,x_tri, start, end, targets, TRAIN)
			test(net,loader_data)


def clean_train(net, trainset, criterion, logdir, num_epochs=500):
	"""Train ``net`` on 90% of ``trainset`` and validate on the other 10%.

	After every epoch the state dict is saved to
	``logdir + 'resnet20_8bit.pkl'``, validation accuracy is reported via
	``test``, and the last batch loss is printed. Requires CUDA.

	Args:
		num_epochs: number of passes over the training split (default 500,
			the previously hard-coded value).
	"""
	net.cuda()
	trainset, testset = train_test_split(trainset, test_size=0.1,shuffle=True)
	optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum =0.9,
	weight_decay=0.0001)
	# NOTE(review): this scheduler is never stepped, so the learning rate
	# stays at base_lr (0.001). Call scheduler.step() per batch if cyclic LR
	# is actually wanted.
	scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.001,max_lr=0.1,step_size_up=500)
	loader_data = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
	loader_val = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)
	loss = None  # avoids an UnboundLocalError below when the loader is empty
	for epoch in range(num_epochs):
		# Re-enter training mode every epoch: the validation pass at the end
		# of the previous epoch may switch the model to eval mode (test1 in
		# this module does), which previously froze BatchNorm for good.
		net.train()
		# The total used to be printed as 200 while 500 epochs ran.
		print('Starting epoch %d / %d' % (epoch + 1, num_epochs))

		for x, y in loader_data:
			loss = criterion(net(x.cuda()), y.cuda())
			optimizer.zero_grad()
			loss.backward()
			optimizer.step()

		torch.save(net.state_dict(), logdir+'resnet20_8bit.pkl')	## checkpoint after every epoch
		test(net,loader_val)

		print(loss)

def bit_reduction_test(net,net1,wb,targets):
	"""Try to reduce the number of bit flips needed to turn ``net1`` into ``net``.

	For every state-dict entry with >= 2 dims and a gradient, both tensors
	are quantized to 8-bit levels with the step ``max|net1 entry| / 127``
	(``a`` = net, ``b`` = net1), keeping only positions where they differ.
	Each level is written as a bit string: low 8 bits in two's complement for
	negative values, plain binary padded to 7 digits for non-negative values.
	``solve`` (from the ``hamming`` module; presumably the Hamming distance)
	compares each pair and is accumulated in ``ctr1``. For pairs that differ
	in more than one bit, a ``change`` is derived from ``setBitNumber``
	(presumably the highest set power of two). The weight is then rewritten
	to ``net1 value +/- change * step`` in ``layer.data``, and the remaining
	distance is accumulated in ``ctr2``. Modifies ``net`` in place and
	returns it.

	NOTE(review): ``tar`` (top-gradient positions) and the mismatch mask
	``idx`` are both indexed by ``k`` below. This assumes they enumerate the
	same weights in the same order — confirm.
	NOTE(review): the ``while`` loops compute ``x % change``; if
	``setBitNumber`` can return 0 this raises ZeroDivisionError — confirm.
	NOTE(review): execution blocks on ``input()`` when a flip cannot be reduced.
	"""
	print('BIT REDUCING...')
	ctr1 = 0  # bit flips between net1 and net (before reduction)
	ctr2 = 0  # bit flips remaining after reduction
	for name, layer in net.state_dict(keep_vars=True).items(): #list(net.named_modules()):
		for name1,layer1 in net1.state_dict(keep_vars=True).items(): #list(net1.named_modules()):
			if name==name1:
				if len(layer.shape)<2 or layer.grad is None:
					# Only >= 2-D entries with a gradient are processed.
					#print(layer1.weight)
					#net.load_state_dict({name: layer1},strict=False)
					continue
				else:
					if 'linear' in name:
						num_param = min(wb, layer.shape[1])
						# Per-row top-|grad| columns; row `targets` is the one that was trained.
						_,w_id=layer.grad.detach().abs().topk(num_param)
						tar=w_id[targets]

					else:
						num_param = min(wb, layer.flatten().shape[0])
						# Index tuples of the largest raw gradient values.
						tar = get_topk(layer.grad, num_param)

				param = layer
				param1 = layer1
				N_bits = 8
				full_lvls = 2**N_bits
				half_lvls = (full_lvls - 2) / 2

				# Positions where net differs from net1.
				idx = torch.not_equal(layer,layer1)
				# Both tensors are quantized with the step of the original (net1) weights.
				step = param1.abs().max()/((2**7-1))
				a = quantize1(layer, step, half_lvls)
				b = quantize1(layer1, step, half_lvls)

				# Mask with bits 8..32 set; XOR-ing with it (and with 1<<32) keeps the
				# low 8 bits of a 32-bit two's-complement negative number.
				aa = 0b111111111111111111111111100000000
				a = a[idx].cpu().data.numpy()
				b = b[idx].cpu().data.numpy()
				if layer[idx].shape[0]<1:
					continue
				new_target_layer=[]  # filled below but never used
				for k in range(layer[idx].shape[0]):
					# Bit strings of the original (before) and modified (after) levels.
					if b[k]<0:
						before = format((1<<32)+int(b[k])^aa^(1<<32),'#09b')
					else:
						before = format((1<<32)+int(b[k])^(1<<32),'#09b')
					if a[k]<0:
						after =  format((1<<32)+int(a[k])^aa^(1<<32),'#09b')
					else:
						after =  format((1<<32)+int(a[k])^(1<<32),'#09b')

					ctr1 += solve(before,after)

					# Already a single flip (or identical): nothing to reduce.
					if solve(before,after)==1 or before==after:
						new_target_layer.append(layer1[idx][k])
						ctr2 += solve(before, after)
						continue
					elif a[k] >= 0 and b[k]>= 0:
						if a[k]<b[k]:
							#change = -0b10000000
							change = -1 * setBitNumber(int(a[k])^int(b[k]))
							if abs(a[k]-b[k])<2:
								change=0
						else:
							z = max(a[k],b[k])
							change = setBitNumber(int(z))
							# before 0b0000100 after 0b0000111
							if setBitNumber(int(a[k])) == setBitNumber(int(b[k])):
								_a = int(a[k])
								_b = int(b[k])
								while setBitNumber(_a) == setBitNumber(_b):
									_a %= change
									_b %= change
									change = setBitNumber(_a%change)

					elif (a[k]<0 and b[k]>=0):
						# Sign change: flip only the sign bit (bit 7).
						change = -0b10000000
					elif  (a[k]>=0 and b[k]<0):
						change = 0b10000000
					elif a[k]<0 and b[k]<0:
						change = setBitNumber((1<<32)+int(min(a[k],b[k]))^aa^(1<<32))

						# before 0b11100010 after 0b11100001
						if setBitNumber((1<<32)+int(a[k])^aa^(1<<32)) == setBitNumber((1<<32)+int(b[k])^aa^(1<<32)):
							_a = int(a[k])
							_b = int(b[k])
							while setBitNumber(_a) == setBitNumber(_b):
								_a %= change
								_b %= change
								if a[k]<b[k]:
									change = setBitNumber(_b%change)
								else:
									change = setBitNumber(_a%change)


					change_quan = int(quantize1(change*step, step, half_lvls*100).cpu().data.numpy())  # computed but unused
					# Apply the change (in units of the quantization step) to the net1 value.
					if b[k] < a[k] or b[k]>0:
						after_change =  layer1[idx][k] + change*step
					else:
						after_change =  layer1[idx][k] - change*step



					after_change_quan = int(quantize1(after_change, step, half_lvls*100).cpu().data.numpy())
					# NOTE(review): tar[k] is the k-th top-gradient position, while
					# after_change belongs to the k-th mismatching position — see docstring.
					if 'linear' in name:
						layer.data[targets,tar[k]]=after_change #after_real_change
					else:
						layer.data[tar[k]]=after_change

					#new_target_layer.append(after_change*1.)
					#if len(new_target_layer)==0:
					#	continue
					#else:
					#	layer.data[idx] = torch.tensor(new_target_layer).cuda() #TODO CHANGE IT

					#if not ( a[k] > after_change_quan and after_change_quan > b[k] or \
					#		a[k] < after_change_quan and after_change_quan < b[k]):
					#	print(a[k],b[k],change, after_change_quan)
					#	input('SOMETHING IS WRONG!')
					#if abs(a[k]-after_change_quan)>abs(b[k]-a[k]):
					#	if True:
					#		print('intervention: ',after_change_quan)
					#	change=0
					#	after_change_quan=int(b[k])

					# Bit string of the new level, compared against the original.
					if after_change_quan<0:
						after = format((1<<32)+int(after_change_quan)^aa^(1<<32),'#09b')
					else:
						after = format((1<<32)+int(after_change_quan)^(1<<32),'#09b')
					ctr2 += solve(before, after)
					if solve(before, after)>1:
						print(before, after)
						input("could not be reduced.")


				print('num parameters:', layer.flatten().shape[0])
				print('total bit flips',ctr1)
				print('total reduced bit flips',ctr2)
	return net
	
def _restore_untargeted_weights(net, net1, targets, tar, param_number=95):
	"""Keep only the trained ``[targets, tar]`` entries of one parameter.

	Parameter ``param_number`` (1-based position in ``net.parameters()``) of
	``net`` is reset to the matching parameter of ``net1``. The exception is
	the entries ``[targets, tar]``, which keep the values just trained in
	``net``.
	"""
	pairs = zip(net.parameters(), net1.parameters())
	for position, (param, param1) in enumerate(pairs, start=1):
		if position != param_number:
			continue
		trained = param.data.clone()  # freshly trained values of net
		param.data = param1.data.clone()  # untouched values from net1
		param.data[targets, tar] = trained[targets, tar].clone()


def train(n,net,net1,wb, testset, criterion, x_tri,start, end, targets,best_loss,writer,logdir, tar,PAGE_CHECK, GLOBAL, TRAIN=True,TBT=False):
	"""Fine-tune one hard-coded parameter of ``net`` to implant trigger ``x_tri``.

	Each of the 400 epochs uses only the first batch of ``testset``. The loss
	is the mean of the clean loss and the loss on trigger-stamped inputs
	relabelled as ``targets``. With ``TBT`` set, after each step only the
	``[targets, tar]`` entries of that parameter keep their trained values
	(the rest is reset to ``net1``). Every 50 global steps
	(``n * 400 + epoch + 1``) the model and the trigger channels are saved
	under ``logdir`` and evaluated, and ``best_loss`` is set to that epoch's
	total loss. ``wb``, ``PAGE_CHECK``, ``GLOBAL`` and ``TRAIN`` are
	currently unused. ``writer`` needs an ``add_scalar`` method (presumably a
	TensorBoard SummaryWriter).

	Returns:
		``best_loss`` (the last saved epoch's loss, or the input value).
	"""
	for param in net.parameters():
		param.requires_grad = False
	# Only the 95th parameter (index 94) is trained.
	# NOTE(review): the index is architecture-specific — confirm for the model.
	list(net.parameters())[94].requires_grad = True

	optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=5e-1, momentum =0.9,
	weight_decay=0.000005)
	scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[800,1200,1600,3000,4000,5000], gamma=0.7)
	loader_data = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)
	num_epoch=400

	### training with clean images and triggered images
	for epoch in range(num_epoch):
		scheduler.step()

		epoch_loss = 0
		epoch_benign_loss = 0
		epoch_adv_loss = 0
		for t, (x, y) in enumerate(loader_data):
			if t==1:
				break  # only the first batch is used per epoch
			## first loss term: clean images with their true labels
			x_var, y_var = to_var(x), to_var(y.long())
			loss = criterion(net(x_var), y_var)
			epoch_benign_loss += loss.detach()
			## second loss term: same images with the trigger, relabelled as targets
			x_var1,y_var1=to_var(x), to_var(y.long())

			x_var1[:,0:3,start:end,start:end]=x_tri[:,0:3,start:end,start:end]
			y_var1[:]=targets

			loss1 = criterion(net(x_var1), y_var1)
			epoch_adv_loss += loss1.detach()
			loss=(loss+loss1)/2  # equal weight for clean and triggered data

			optimizer.zero_grad()
			loss.backward()
			optimizer.step()

			# Detached accumulators: the logged/returned losses no longer keep
			# the autograd graph alive.
			epoch_loss += loss.detach()
			if TBT:
				# Only the selected weights of the target row may change. This
				# used to reuse ``n`` as a parameter counter, which clobbered
				# the iteration number used for logging and saving below.
				_restore_untargeted_weights(net, net1, targets, tar)

		global_step = n*num_epoch+epoch+1
		writer.add_scalar('Adversarial Loss', epoch_adv_loss,global_step)
		writer.add_scalar('Benign Loss', epoch_benign_loss,global_step)
		writer.add_scalar('Total Loss', epoch_loss,global_step)

		if global_step % 50 == 0:
			# Saved unconditionally (the best-loss comparison was disabled).
			# bit_reduction_test(net, net1, wb, targets) may be applied here
			# first if bit-flip reduction is wanted.
			print('Best model saved at epoch %d / %d of iteration %d' % (epoch + 1, num_epoch, n))
			torch.save(net.state_dict(), logdir+'Resnet18_8bit_all_layers_trojan.pkl')	## saving the trojaned model
			# Save the trigger image channels for future use.
			np.savetxt(logdir+'trojan_last_layer_img1.txt', x_tri[0,0,:,:].cpu().numpy(), fmt='%f')
			np.savetxt(logdir+'trojan_last_layer_img2.txt', x_tri[0,1,:,:].cpu().numpy(), fmt='%f')
			np.savetxt(logdir+'trojan_last_layer_img3.txt', x_tri[0,2,:,:].cpu().numpy(), fmt='%f')

			# Periodic evaluation (trojan accuracy, then clean accuracy),
			# logged at the current step rather than a constant n*num_epoch.
			acc1 = test1(net,loader_data,x_tri, start, end, targets, True)
			acc = test(net,loader_data)
			writer.add_scalar("Trojan accuracy",acc1, global_step)
			writer.add_scalar("Clean accuracy",acc, global_step)

			best_loss = epoch_loss

	return best_loss

def select_parameters_global(net, wb, PAGE_CHECK):
	"""Pick the ``wb`` parameters with the largest |gradient| across ``net``.

	Candidates are state-dict entries with at least 2 dims and a ``.grad``,
	excluding ``0.mean`` / ``0.std``. They are numbered 0, 1, ... in
	state-dict order. With ``PAGE_CHECK`` at most one parameter is taken per
	page of ``4096 * PAGEPERBIT`` elements of the concatenated gradient
	vector.

	Returns:
		dict mapping each layer number (as a string) to the list of selected
		flat indices within that layer (in selection order).
	"""
	grad_chunks = []
	sizes = []
	layer_indices = {}

	# Gather the gradients and build the layer_indices template.
	for name, layer in net.state_dict(keep_vars=True).items():
		if len(layer.shape)<2 or layer.grad is None or name=='0.mean' or name=='0.std':
			continue
		layer_indices[str(len(sizes))] = []
		grad_chunks.append(layer.grad.flatten().cpu().numpy())
		sizes.append(layer.flatten().shape[0])

	# wb <= 0 used to select every parameter, because the stop test only
	# compared the count for equality after each append.
	if wb <= 0 or not sizes:
		print('total number of selected targets', 0)
		return layer_indices

	grads = np.concatenate(grad_chunks)
	layer_ends = np.cumsum(sizes)  # exclusive end offset of each layer
	PAGEPERBIT = 1 #len(grads)//(4096*wb)
	page_size = 4096 * PAGEPERBIT

	# Walk the gradients from largest to smallest magnitude. A set gives
	# O(1) page-membership tests (the list made this quadratic).
	selected_pages = set()
	n_selected = 0
	for index in np.argsort(-np.abs(grads)):
		page = index // page_size
		if PAGE_CHECK and page in selected_pages:
			continue
		selected_pages.add(page)
		# First layer whose end offset exceeds index (same as get_layer_number).
		layer = int(np.searchsorted(layer_ends, index, side='right'))
		layer_start = layer_ends[layer] - sizes[layer]
		layer_indices[str(layer)].append(index - layer_start)
		n_selected += 1
		if n_selected == wb:
			break
	print('total number of selected targets', n_selected)

	return layer_indices
			
def update_parameters(net, net1, layer_indices):
	"""Reset ``net`` to ``net1`` except for the selected entries.

	Entries with fewer than 2 dims, without a gradient, or named
	``0.mean`` / ``0.std`` are copied from ``net1`` verbatim. The remaining
	entries are numbered in state-dict order (matching
	``select_parameters_global``). Each is reset to ``net1``, except the flat
	indices listed in ``layer_indices[str(number)]``, which keep ``net``'s
	values. Modifies ``net`` in place and returns it.
	"""
	reference = net1.state_dict(keep_vars=True)
	counter = 0
	for name, layer in net.state_dict(keep_vars=True).items():
		if name not in reference:
			continue
		layer1 = reference[name]
		frozen = (len(layer.shape) < 2 or layer.grad is None
				  or name == '0.mean' or name == '0.std')
		if frozen:
			net.load_state_dict({name: layer1}, strict=False)
			continue
		shape = np.array(layer.shape)
		coords = [np.unravel_index(flat, shape) for flat in layer_indices[str(counter)]]
		kept = layer.data.clone()
		layer.data = layer1.clone()
		for coord in coords:
			layer.data[coord] = kept[coord].clone()
		counter += 1

	return net

def get_layer_number(index, layer_sizes):
	"""Locate flat ``index`` within layers laid out back to back.

	Returns:
		``(layer, layer_start)``: the position of the first layer whose end
		offset exceeds ``index``, and the flat offset where that layer starts.

	Raises:
		ValueError: if ``index`` is not below ``sum(layer_sizes)``. This
		previously printed a message and terminated the interpreter with
		``exit()``.
	"""
	start = 0
	# Running sum instead of re-summing a prefix for every layer.
	for i, size in enumerate(layer_sizes):
		end = start + size
		if end > index:
			return i, start
		start = end
	raise ValueError('Error in get_layer_number(): index %r is out of range '
					 'for a total size of %r.' % (index, start))

def select_one_parameter_per_page(net, net1,  PAGE_CHECK, Nflip=None):
	"""Rank the parameters that differ between ``net`` and ``net1`` by |grad|.

	Candidates are state-dict entries of ``net`` with at least 2 dims and a
	``.grad``, excluding ``0.mean`` / ``0.std``, numbered in state-dict order.
	Every name must also exist in ``net1``. Only positions whose value
	differs from ``net1`` are kept, ordered by decreasing |gradient|. With
	``PAGE_CHECK`` at most one is kept per page of ``4096 * PAGEPERBIT``
	elements, where ``PAGEPERBIT = len(grads) // (4096 * Nflip)`` (at least 1).

	Args:
		Nflip: number of bit flips the pages are divided among. The original
			code read an undefined global ``Nflip``. When omitted, a
			module-level ``Nflip`` is still honoured if a caller has set one.

	Returns:
		dict mapping each layer number (as a string) to the list of selected
		flat indices within that layer.

	Raises:
		TypeError: if ``Nflip`` is neither passed nor set at module level.
	"""
	if Nflip is None:
		Nflip = globals().get('Nflip')
		if Nflip is None:
			raise TypeError('select_one_parameter_per_page() needs the number of '
							'bit flips: pass Nflip=...')

	reference = net1.state_dict(keep_vars=True)
	weight_chunks = []
	weight1_chunks = []
	grad_chunks = []
	sizes = []
	layer_indices = {}

	# Gather gradients/weights and build the layer_indices template.
	for name, layer in net.state_dict(keep_vars=True).items():
		layer1 = reference[name]
		if len(layer.shape)<2 or layer.grad is None or name=='0.mean' or name=='0.std':
			continue
		layer_indices[str(len(sizes))] = []
		grad_chunks.append(layer.grad.flatten().cpu().numpy())
		weight_chunks.append(layer.detach().flatten().cpu().numpy())
		weight1_chunks.append(layer1.detach().flatten().cpu().numpy())
		sizes.append(layer.flatten().shape[0])

	if not sizes:
		print('total number of selected targets', 0)
		return layer_indices

	grads = np.concatenate(grad_chunks)
	# Pages per bit; at least 1, because a model smaller than 4096 * Nflip
	# elements used to raise ZeroDivisionError below.
	PAGEPERBIT = max(1, len(grads)//(4096*Nflip))
	page_size = 4096 * PAGEPERBIT

	# Flat indices of the modified parameters, by decreasing |gradient|.
	changed = np.flatnonzero(np.concatenate(weight_chunks) != np.concatenate(weight1_chunks))
	order = np.argsort(-np.abs(grads[changed]))
	layer_ends = np.cumsum(sizes)  # exclusive end offset of each layer

	selected_pages = set()  # O(1) membership (was a list scan)
	n_selected = 0
	for i in order:
		index = changed[i]
		page = index // page_size
		if PAGE_CHECK and page in selected_pages:
			continue
		selected_pages.add(page)
		# First layer whose end offset exceeds index (same as get_layer_number).
		layer = int(np.searchsorted(layer_ends, index, side='right'))
		layer_start = layer_ends[layer] - sizes[layer]
		layer_indices[str(layer)].append(index - layer_start)
		n_selected += 1

	print('total number of selected targets', n_selected)

	return layer_indices
	
